# coding=utf-8
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objs as go

# Read the raw CSV tables from the current working directory.
aisles = pd.read_csv('./aisles.csv')
order_product = pd.read_csv('./order_products__prior.csv')
orders = pd.read_csv('./orders.csv')
products = pd.read_csv('./products.csv')

# Join everything into one wide table:
#   order lines -> product details -> order metadata -> aisle name.
# pd.merge defaults to an inner join; each step joins on the single key
# column the two frames share.
table1 = pd.merge(order_product, products, on="product_id")
table2 = pd.merge(table1, orders, on="order_id")
table = pd.merge(table2, aisles, on="aisle_id")

# Treat '?' placeholder cells as missing values.
table.replace('?', value=np.nan, inplace=True)

# Schema of `table` as reported by table.info():
#  0   order_id                int64
#  1   product_id              int64
#  2   add_to_cart_order       int64
#  3   reordered               int64
#  4   product_name            object
#  5   aisle_id                int64
#  6   department_id           int64
#  7   user_id                 int64
#  8   eval_set                object
#  9   order_number            int64
#  10  order_dow               int64
#  11  order_hour_of_day       int64
#  12  days_since_prior_order  float64
#  13  aisle                   object


def drwa_bar_of_sales(top_n=10):
    """Show a bar chart of the best-selling products in ``table``.

    "Sales" for a product is the number of rows of ``table`` with that
    ``product_name`` (the count of non-null ``order_id`` values in the group),
    i.e. how many order lines contained the product.

    Args:
        top_n: how many products to plot, highest count first. Defaults to 10.
    """
    palette = ['rgb(205,38,38)', 'rgb(205,100,38)', 'rgb(34,139,34)', 'rgb(150,38,38)',
               'rgb(205,38,230)', 'rgb(205,20,38)', 'rgb(100,101,102)', 'rgb(150,60,38)',
               'rgb(18,46,60)', 'rgb(90,38,80)']
    sales_per_product = table.groupby(by='product_name')['order_id'].count()
    top_products = sales_per_product.sort_values(ascending=False).head(top_n)
    # Cycle the fixed palette so every bar gets a colour whatever top_n is.
    colors = [palette[i % len(palette)] for i in range(len(top_products))]
    fig = go.Figure(data=go.Bar(x=top_products.index, y=top_products.values,
                                marker=dict(color=colors)))
    fig.update_layout(title='Top {} of sales'.format(top_n))
    fig.show()


def order_bull(user_id=201268, min_total=2000):
    """Show a donut chart of what a single heavy customer buys.

    Selects the rows of ``table`` for ``user_id``, sums ``order_number`` per
    ``product_name`` and plots the products whose sum is strictly greater
    than ``min_total``.

    Args:
        user_id: customer to plot. Defaults to 201268. An earlier exploration
            (ranking ``table.groupby('user_id')['order_number'].sum()``)
            singled out users 201268, 164055 and 176478 as the top customers.
        min_total: products whose summed ``order_number`` does not exceed
            this value are left out of the chart. Defaults to 2000.

    Raises:
        ValueError: if ``user_id`` does not occur in ``table``.
    """
    # Select the user's rows with a boolean mask instead of iterating over
    # every user group just to find one of them.
    user_rows = table[table['user_id'] == user_id]
    if user_rows.empty:
        raise ValueError('user_id {} not found in table'.format(user_id))

    # NOTE(review): order_number looks like a per-user order sequence number,
    # so summing it weights later orders more heavily. A plain purchase count
    # (.size()) may be the intended "shopping distribution"; confirm before
    # relying on this chart.
    totals = user_rows.groupby(by='product_name')['order_number'].sum()
    totals = totals[totals > min_total]

    fig = go.Figure(data=[go.Pie(labels=totals.index.tolist(),
                                 values=totals.tolist(), hole=.3,
                                 title='Shopping distribute of TOP1_BULL')])
    fig.show()


if __name__ == '__main__':
    # Charts rendered when this file is run as a script. Add
    # drwa_bar_of_sales to the tuple to also show the best-sellers bar chart.
    for render_chart in (order_bull,):
        render_chart()